import logging
import time
from concurrent.futures import ThreadPoolExecutor

import cv2
import numpy as np
import pymysql
import torch.cuda
from PIL import Image
from flask import jsonify
from sqlalchemy import create_engine, update
from sqlalchemy.orm import sessionmaker, scoped_session
from werkzeug.exceptions import BadRequest

import config
from app.model.models import DeepfakeVideo
from app.recognition import DeepfakeDetectionSIFDNet
from app.recognition.face_detector import FacenetDetector
from app.utils import file_utils, metrics_utils
from app.utils.image_utils import crop_and_resize, ndarray_error_threshold
from . import bp

# Module-level logger for this blueprint module.
LOG = logging.getLogger(__name__)

# Engine and session factory built from the configured database URI.
engine = create_engine(config.get_config('SQLALCHEMY_DATABASE_URI'))
session_factory = sessionmaker(bind=engine)
# Scoped session registry: run_detection() calls Session() to obtain the
# session it uses for its status/progress updates.
Session = scoped_session(session_factory)

# Single-worker executor: jobs submitted by deepfake_detection() run one at a time.
pool = ThreadPoolExecutor(max_workers=1)
# loop = asyncio.get_event_loop()

# FPS = 25
# Detector instances are created once at import time and handed to every
# detection job submitted by deepfake_detection().
FACE_DETECTOR = FacenetDetector(landmarks=False)
DEEPFAKE_DETECTOR = DeepfakeDetectionSIFDNet.instance()


def exception_callback(future):
    """Log the exception, if any, that a finished executor job raised.

    Meant to be registered with ``Future.add_done_callback``.

    Args:
        future: A completed (or cancelled) ``concurrent.futures.Future``.
    """
    # Future.exception() raises CancelledError for a cancelled future, so a
    # cancelled job is simply ignored: there is no job exception to report.
    if future.cancelled():
        return
    exception = future.exception()
    if exception is not None:
        # exc_info keeps the worker's traceback in the log record.
        LOG.error(exception, exc_info=exception)


def _set_video_fields(session, video_id, **values):
    """Write *values* to the DeepfakeVideo row with id *video_id* and commit."""
    session.execute(
        update(DeepfakeVideo)
        .where(DeepfakeVideo.id == video_id)
        .values(**values)
    )
    session.commit()


def run_detection(video_entity, face_detector, deepfake_detector):
    """Run deepfake detection on one stored video and persist the outcome.

    The DeepfakeVideo row is updated as the job advances: ``status=1`` and
    ``progress=0`` at start, ``progress`` periodically while frames are
    processed, ``status=2``/``progress=100``/``probability`` on success and
    ``status=3``/``progress=0``/``probability=0`` on failure. On success a
    mask video ``<video_name[:-4]>_mask.webm`` is written to the video folder.

    Args:
        video_entity: DeepfakeVideo row; its ``id`` and ``video_name`` are read.
        face_detector: Object whose ``detect_faces(image)`` yields bounding
            boxes as ``(xmin, ymin, xmax, ymax)``.
        deepfake_detector: Object whose ``get_predictions(crop)`` returns a
            ``(fake_prob, mask_pred)`` pair.

    Returns:
        None. Errors are logged and recorded on the row, never raised.
    """
    session = Session()
    try:
        video_id = video_entity.id
        video_name = video_entity.video_name

        # Init status=1 and progress=0
        _set_video_fields(session, video_id, status=1, progress=0)
        video_path = str(file_utils.VIDEO_FOLDER / video_name)
        frames, fps = file_utils.get_video_frames_and_fps(video_path)

        frame_count = len(frames)
        # Report progress roughly once per second of video. fps may be
        # fractional (e.g. 29.97) or 0, in which case a direct
        # ``% fps`` would never hit zero or would raise.
        progress_step = max(int(round(fps)), 1)

        mask_frames = []
        fake_probs = []

        for index, frame in frames.items():
            # Update execution progress
            if (index + 1) % progress_step == 0:
                progress = int((index + 1) / frame_count * 100)
                _set_video_fields(session, video_id, progress=progress)
                if not torch.cuda.is_available():
                    # NOTE(review): presumably throttles the CPU-only path so
                    # other threads get a chance to run - confirm.
                    time.sleep(0.5)

            # Init mask frame
            mask_frame = np.zeros_like(frame)

            # Detect faces on a half-resolution copy of the frame.
            resized_image = Image.fromarray(frame)
            resized_image = resized_image.resize(
                size=(resized_image.size[0] // 2, resized_image.size[1] // 2)
            )
            bboxes = face_detector.detect_faces(resized_image)
            if bboxes is None:
                bboxes = ()

            # Record fake probability of one frame
            frame_fake_prob_sum = 0.
            for bbox in bboxes:
                # Scale the box back to full resolution and pad it by a third
                # of its size on every side before cropping the face.
                xmin, ymin, xmax, ymax = [int(b * 2) for b in bbox]
                p_h = (ymax - ymin) // 3
                p_w = (xmax - xmin) // 3
                top, bottom = max(ymin - p_h, 0), ymax + p_h
                left, right = max(xmin - p_w, 0), xmax + p_w
                crop = frame[top:bottom, left:right]
                fake_prob, mask_pred = deepfake_detector.get_predictions(crop)
                frame_fake_prob_sum += fake_prob

                # Fill the frame mask with the mask of manipulation region.
                mask_pred = ndarray_error_threshold(mask_pred)
                mask_pred = crop_and_resize(mask_pred, crop)
                mask_pred = mask_pred * 255
                mask_pred = np.expand_dims(mask_pred, axis=2)
                mask_pred = np.tile(mask_pred, (1, 1, 3))
                mask_frame[top:bottom, left:right] = mask_pred

            # A frame without any detected face contributes probability 0
            # instead of dividing by zero and failing the whole job.
            face_count = len(bboxes)
            fake_probs.append(frame_fake_prob_sum / face_count if face_count else 0.)
            mask_frames.append(mask_frame)

        # Update status=2, progress=100 and the aggregated probability
        fake_probs = np.array(fake_probs)
        result = metrics_utils.get_video_fake_prob(fake_probs)
        _set_video_fields(session, video_id, status=2, progress=100, probability=result)

        # Save mask video
        # NOTE(review): status=2 is committed before the mask video exists;
        # a failure below flips the row to status=3.
        shape = mask_frames[0].shape
        mask_video_path = str(file_utils.VIDEO_FOLDER / f'{video_name[:-4]}_mask.webm')
        fourcc = cv2.VideoWriter_fourcc(*"vp80")
        video_tracked = cv2.VideoWriter(mask_video_path, fourcc, fps, (shape[1], shape[0]))
        try:
            for mask in mask_frames:
                video_tracked.write(cv2.cvtColor(mask, cv2.COLOR_RGB2BGR))
        finally:
            # Release the writer even if a frame fails to encode.
            video_tracked.release()
    except pymysql.err.OperationalError:
        # Database-level failure: the row cannot be updated, so only log.
        LOG.exception('Database error during deepfake detection')
    except Exception:
        LOG.exception('Deepfake detection failed')
        try:
            # Clear any failed transaction before reusing the session.
            session.rollback()
            _set_video_fields(session, video_entity.id, status=3, progress=0, probability=0.)
        except Exception:
            LOG.exception('Could not record the detection failure')
    finally:
        # Close the session and drop it from the scoped registry so the
        # long-lived worker thread does not hold it between jobs.
        Session.remove()


@bp.route('/deepfake/detection/<path:video_id>', methods=['GET'])
def deepfake_detection(video_id):
    """Queue a deepfake detection job for the video with id *video_id*.

    The job runs in the module's single-worker thread pool; this view only
    submits it and returns immediately.

    Args:
        video_id: Id of the DeepfakeVideo row, taken from the URL.

    Returns:
        A JSON response with ``success=True`` once the job is queued, or
        ``success=False`` if the job could not be submitted.

    Raises:
        BadRequest: If the video does not exist, is already being analyzed
            (``status == 1``), or another video is currently being analyzed.
    """
    video_entity = DeepfakeVideo.query.filter_by(id=video_id).first()
    if not video_entity:
        raise BadRequest('Video does not exist.')
    if video_entity.status == 1:
        raise BadRequest('Video is being analyzed.')
    running_count = DeepfakeVideo.query.filter_by(status=1).count()
    if running_count >= 1:
        raise BadRequest('The maximum number of tasks that can be run is 1. Please try again later.')

    try:
        future = pool.submit(run_detection, video_entity, FACE_DETECTOR, DEEPFAKE_DETECTOR)
    except Exception:
        LOG.exception('Failed to submit detection job for video %s', video_id)
        return jsonify(success=False, data="执行失败，请稍后重试！", status_code=500)
    # Nobody waits on the future, so without a callback any exception that
    # escapes the worker would be dropped silently.
    future.add_done_callback(exception_callback)
    return jsonify(success=True, data=None, status_code=200)


@bp.route('/detection/progress/<path:video_id>', methods=['GET'])
def detection_progress(video_id):
    """Return the detection status, progress and probability of one video.

    Args:
        video_id: Id of the DeepfakeVideo row, taken from the URL.

    Returns:
        A JSON response whose ``data`` holds the row's ``status``,
        ``progress`` and ``probability`` columns.

    Raises:
        BadRequest: If no video with this id exists.
    """
    detection_result = DeepfakeVideo.query.with_entities(
        DeepfakeVideo.status, DeepfakeVideo.progress, DeepfakeVideo.probability
    ).filter_by(id=video_id).first()
    # first() yields None for an unknown id; report it as a client error
    # instead of failing on None._asdict().
    if detection_result is None:
        raise BadRequest('Video does not exist.')
    return jsonify(success=True, data=detection_result._asdict(), status_code=200)
